
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


from torchvision import datasets
from torchvision import transforms


##-----------------GAN model-----------------##

# Generator: 100-d latent vector -> flattened 28x28 image. The final Tanh bounds
# pixels to [-1, 1], the same range the MNIST transform normalises real data to.
# Discriminator: flattened 28x28 image -> Sigmoid probability that it is real.
# Keep the layer order fixed: nn.Sequential state_dict keys are positional
# ("0.weight", "3.running_mean", ...) and construction order fixes the random init.
_LATENT_DIM = 100
_IMG_DIM = 28 * 28


def _leaky_relu():
    """Return a fresh in-place LeakyReLU with negative slope 0.2."""
    return nn.LeakyReLU(0.2, inplace=True)


generator = nn.Sequential(
    nn.Linear(_LATENT_DIM, 128), _leaky_relu(),
    nn.Linear(128, 256), nn.BatchNorm1d(256), _leaky_relu(),
    nn.Linear(256, _IMG_DIM), nn.Tanh(),
)

discriminator = nn.Sequential(
    nn.Linear(_IMG_DIM, 256), _leaky_relu(),
    nn.Linear(256, 128), _leaky_relu(),
    nn.Linear(128, 1), nn.Sigmoid(),
)


## Load MNIST dataset
# ToTensor() scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1], matching the generator's final Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
mnist_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
data_loader = DataLoader(mnist_dataset, batch_size=128, shuffle=True)

# One Adam optimizer per network, both with learning rate 2e-4.
optimizer_G = optim.Adam(generator.parameters(), lr=2e-4)
optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4)

# Binary cross-entropy on the discriminator's Sigmoid output.
adversarial_loss = nn.BCELoss()

# Train the GAN for 501 epochs (epoch indices 0..500). Each batch performs one
# generator update followed by one discriminator update. Generator weights are
# checkpointed at the end of epochs 100, 300 and 500.
for epoch in range(501):
    for imgs, _ in data_loader:
        n_batch = imgs.size(0)
        # Adversarial ground truths: 1 = real, 0 = fake
        valid = torch.ones(n_batch, 1)
        fake = torch.zeros(n_batch, 1)

        # Flatten images to 28*28 vectors to match the MLP input layer
        real_imgs = imgs.view(n_batch, -1)
        z = torch.randn(n_batch, 100)

        # Train Generator: push D(G(z)) towards the "real" label
        optimizer_G.zero_grad()
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator. g_loss.backward() also accumulated gradients in
        # the discriminator's parameters; zero_grad() discards them first.
        # gen_imgs is detached so this step does not backprop into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    # Save once per checkpoint epoch, after the epoch's last batch.
    if epoch in (100, 300, 500):
        PATH = f'state_dict_gene_{epoch}.pt'
        torch.save(generator.state_dict(), PATH)
    print(epoch)




##-----------------MNIST Classifier model-----------------##

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With stride 1 the spatial size is preserved; with stride 2 it is halved
    (rounded up).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The result is added to
    the shortcut and passed through a final ReLU. ``downsample``, when given,
    is applied to the input to form the shortcut. The caller supplies it when
    the main path changes stride or channel count.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Attribute names become state_dict keys; they and their registration
        # order are kept fixed so saved classifier weights still load.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)


# Classifier hyperparameters.
# NOTE(review): RANDOM_SEED, LEARNING_RATE, BATCH_SIZE, NUM_EPOCHS and
# NUM_FEATURES are not read by the code visible here; presumably they record
# how classifier.pth was trained - confirm.
RANDOM_SEED = 1
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
NUM_EPOCHS = 10

# Architecture: flattened 28x28 input size, 10 digit classes, and a
# 1-channel (grayscale) input stem for the ResNet.
NUM_FEATURES = 28 ** 2
NUM_CLASSES = 10
GRAYSCALE = True

class ResNet(nn.Module):
    """ResNet backbone with a linear head; ``forward`` returns ``(logits, probas)``.

    Args:
        block: Residual block class (e.g. ``BasicBlock``). It must expose an
            ``expansion`` attribute and accept
            ``(inplanes, planes, stride, downsample)``.
        layers: Four ints giving the number of blocks in each of the four stages.
        num_classes: Output size of the final ``nn.Linear``.
        grayscale: If truthy the stem conv takes 1 input channel, otherwise 3.

    NOTE(review): ``forward`` skips average pooling and flattens the layer4
    output directly. That matches ``fc``'s ``512 * block.expansion`` inputs
    only when the final feature map is 1x1, which holds for 28x28 inputs.
    """

    def __init__(self, block, layers, num_classes, grayscale):
        # Channel count entering the next stage; _make_layer reads and advances it.
        self.inplanes = 64
        if grayscale:
            in_dim = 1
        else:
            in_dim = 3
        super(ResNet, self).__init__()
        # Stem: 7x7 stride-2 conv, then 3x3 stride-2 max-pool (28x28 -> 14x14 -> 7x7).
        self.conv1 = nn.Conv2d(in_dim, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; stages 2-4 halve the spatial size (7 -> 4 -> 2 -> 1 for MNIST).
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Constructed but not used by forward() (average pooling is skipped there).
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-normal init for convs: std = sqrt(2 / (k_h * k_w * out_channels)).
        # BatchNorm starts as identity (scale 1, shift 0). nn.Linear keeps
        # PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, (2. / n)**.5)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks with ``planes`` channels.

        Only the first block may change stride or channel count. It gets a
        1x1-conv + BatchNorm projection shortcut when either changes.
        Side effect: advances ``self.inplanes`` to ``planes * block.expansion``.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(logits, probas)`` where ``probas = softmax(logits, dim=1)``.

        Callers in this script pass images reshaped to (N, 1, 28, 28).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # For 28x28 MNIST input the feature map is already 1x1 here
        # (14 -> 7 -> 7 -> 4 -> 2 -> 1), so average pooling is skipped:
        #x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        logits = self.fc(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas



def resnet18(num_classes, grayscale=None):
    """Construct a ResNet-18 (``BasicBlock`` with stage sizes [2, 2, 2, 2]).

    Args:
        num_classes: Output size of the classification head.
        grayscale: If True the stem takes 1 input channel, otherwise 3.
            ``None`` (the default) uses the module-level ``GRAYSCALE`` flag.

    Returns:
        A freshly initialised ``ResNet`` model.
    """
    if grayscale is None:
        grayscale = GRAYSCALE
    model = ResNet(block=BasicBlock,
                   layers=[2, 2, 2, 2],
                   num_classes=num_classes,
                   grayscale=grayscale)
    return model



# Load the classifier
# Build the ResNet-18 classifier, load its trained weights, and switch to
# inference mode (BatchNorm uses running statistics).
# map_location='cpu' lets a checkpoint saved from a GPU load here; this script
# never moves tensors to a GPU.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model_class = resnet18(NUM_CLASSES)
Base = torch.load('classifier.pth', map_location='cpu')
model_class.load_state_dict(Base)
model_class.eval()



##-----------------Generate images-----------------##

# Load the three GAN generator checkpoints
# Generator weights saved at epochs 100, 300 and 500.
# NOTE(review): the training loop writes them to the working directory; these
# paths assume they were moved into Discriminative_Estimation_TV/ - confirm.
# map_location='cpu' allows loading checkpoints that were saved from a GPU.
Parameter100 = torch.load('Discriminative_Estimation_TV/state_dict_gene_100.pt',
                          map_location='cpu')
Parameter300 = torch.load('Discriminative_Estimation_TV/state_dict_gene_300.pt',
                          map_location='cpu')
Parameter500 = torch.load('Discriminative_Estimation_TV/state_dict_gene_500.pt',
                          map_location='cpu')

# Use the epoch-300 weights for the preview plot; eval() makes BatchNorm1d use
# its running statistics.
# NOTE(review): strict=False silently ignores missing/unexpected keys. The
# checkpoints appear to come from this same `generator`, so strict=True would
# catch a mismatched file instead of silently keeping random weights.
generator.load_state_dict(Parameter300, strict=False)
generator.eval()


# Plot images
import matplotlib.pyplot as plt

# Sample 100 images from the currently loaded generator with a fixed seed.
torch.manual_seed(3)
z = torch.randn(100, 100)
# Map the generator's Tanh output from [-1, 1] to [0, 1]; these [0, 1] images
# are both displayed and fed to the classifier.
gen_imgs = (generator(z).detach().cpu()+1)/2

# Classifier prediction (argmax of logits) for each sample. The first 25 are
# printed as a 5x5 grid matching the plot layout below.
with torch.no_grad():
    pred_labels = model_class(gen_imgs.view(-1, 1, 28, 28))[0].argmax(dim=1)
print(pred_labels[:25].view(5, 5))

# Show the first 25 samples in a 5x5 grid; lay out and display the figure once.
fig, axs = plt.subplots(5, 5)
plt.subplots_adjust(left=0.01, right=0.99, top=0.99, bottom=0.01)
for cnt, ax in enumerate(axs.flat):
    ax.imshow(gen_imgs.view(-1, 1, 28, 28)[cnt][0], cmap='gray')
    ax.axis('off')
plt.show()





# Build labelled synthetic datasets from the epoch-100/300/500 generators.
# The same latent codes (z_train, z_test) are used for every checkpoint, so the
# three datasets differ only in the generator weights. Images are mapped from
# the generator's [-1, 1] range to [0, 1]; labels are the classifier's argmax.
z_train = torch.randn(100000, 100)
z_test = torch.randn(100000, 100)


def _generate_and_label(z, batch_size=10000):
    """Run the loaded ``generator`` on ``z`` and label the images with ``model_class``.

    Works in chunks of ``batch_size`` under ``torch.no_grad()``. This avoids
    building an autograd graph, and its saved activations, for 100k ResNet
    forward passes at once.

    Returns:
        (images, labels): images of shape (len(z), 28*28) with values in
        [0, 1], and int64 class indices of shape (len(z),).
    """
    images, labels = [], []
    with torch.no_grad():
        for z_chunk in torch.split(z, batch_size):
            imgs = (generator(z_chunk).cpu() + 1) / 2
            probas = model_class(imgs.view(-1, 1, 28, 28))[1]
            images.append(imgs)
            labels.append(probas.argmax(dim=1))
    return torch.cat(images), torch.cat(labels)


def _build_gan_dataset(params, epoch):
    """Load generator ``params``, then save GAN_<epoch>_train.pth / _test.pth.

    Each file holds ``{'image': images, 'label': labels}``.

    Returns:
        (train_images, test_images, train_labels, test_labels)
    """
    generator.load_state_dict(params, strict=False)
    generator.eval()
    img_train, y_train = _generate_and_label(z_train)
    img_test, y_test = _generate_and_label(z_test)
    torch.save({'image': img_train, 'label': y_train}, f'GAN_{epoch}_train.pth')
    torch.save({'image': img_test, 'label': y_test}, f'GAN_{epoch}_test.pth')
    return img_train, img_test, y_train, y_test


# GAN 100 images
Image_100_train, Image_100_test, Y_100_train, Y_100_test = _build_gan_dataset(
    Parameter100, 100)

# GAN 300 images
Image_300_train, Image_300_test, Y_300_train, Y_300_test = _build_gan_dataset(
    Parameter300, 300)

# GAN 500 images
Image_500_train, Image_500_test, Y_500_train, Y_500_test = _build_gan_dataset(
    Parameter500, 500)


"""
from collections import Counter
Counter(Y_100_train.tolist())
Counter(Y_100_test.tolist())
Counter(Y_300_train.tolist())
Counter(Y_300_test.tolist())
Counter(Y_500_train.tolist())
Counter(Y_500_test.tolist())
"""
